{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm.auto import tqdm\n",
    "import itertools\n",
    "import os\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare(a,b):\n",
    "    \"\"\"Compare two bracketed integer tokens such as '<3>' and '<7>'.\n",
    "\n",
    "    Surrounding '<' / '>' characters are stripped before the integer parse.\n",
    "    Returns 0 when a < b, 1 when a == b and 2 when a > b.\n",
    "    \"\"\"\n",
    "    left = int(a.strip(\"<>\"))\n",
    "    right = int(b.strip(\"<>\"))\n",
    "    # the sign of (left - right) is -1 / 0 / +1; shift it into 0 / 1 / 2\n",
    "    return (left > right) - (left < right) + 1\n",
    "    \n",
    "def build_dicts(entities):\n",
    "    \"\"\"Build index <-> item lookups for an iterable of hashable items.\n",
    "\n",
    "    Duplicates are skipped, so each distinct item gets exactly one index,\n",
    "    assigned in order of first appearance.\n",
    "\n",
    "    Returns:\n",
    "        (ind2entity, entity2ind): the list of distinct items and the dict\n",
    "        mapping each item to its position in that list.\n",
    "    \"\"\"\n",
    "    entity2ind = dict()\n",
    "    ind2entity = []\n",
    "    for entity in entities:\n",
    "        # dict membership is O(1); scanning ind2entity made this loop quadratic\n",
    "        if entity not in entity2ind:\n",
    "            entity2ind[entity] = len(ind2entity)\n",
    "            ind2entity.append(entity)\n",
    "    return ind2entity, entity2ind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fix the RNG seeds so that Restart & Run All regenerates the same dataset\n",
    "random_seed = 42\n",
    "random.seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "vocab = []\n",
    "\n",
    "# entity tokens: <e_0> ... <e_{num_entities-1}>\n",
    "num_entities = 1000\n",
    "entities = [\"<e_{}>\".format(i) for i in range(num_entities)]\n",
    "vocab = vocab + entities\n",
    "ind2entity, entity2ind = build_dicts(entities)\n",
    "\n",
    "# attribute tokens: <attr_0> ... <attr_{num_attributes-1}>\n",
    "num_attributes = 20\n",
    "attributes = [\"<attr_{}>\".format(i) for i in range(num_attributes)]\n",
    "vocab = vocab + attributes\n",
    "ind2attribute, attribute2ind = build_dicts(attributes)\n",
    "\n",
    "num_vals_per_attr = 20  # values range from [0, num_vals_per_attr-1]\n",
    "values = [\"<{}>\".format(i) for i in range(num_vals_per_attr)]\n",
    "vocab = vocab + values\n",
    "\n",
    "# randomly assign a value to every (entity, attribute) pair\n",
    "# atomic_KB[entity id, attribute id] -> value\n",
    "atomic_KB = np.random.randint(low=0, high=num_vals_per_attr, size=(num_entities, num_attributes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rand_flip(tup):\n",
    "    \"\"\"Return a new tuple holding the items of tup in a random order.\"\"\"\n",
    "    shuffled = [*tup]\n",
    "    # in-place shuffle of the copy, driven by the global `random` state\n",
    "    random.shuffle(shuffled)\n",
    "    return tuple(shuffled)\n",
    "    \n",
    "def choose(arr, ratio_or_count):\n",
    "    \"\"\"Randomly pick items from arr without replacement.\n",
    "\n",
    "    Args:\n",
    "        arr: sequence to sample from.\n",
    "        ratio_or_count: a float is read as the fraction of arr to keep\n",
    "            (rounded to the nearest count); an int is read as an absolute\n",
    "            count. NumPy scalar floats / ints are accepted as well.\n",
    "\n",
    "    Returns:\n",
    "        arr itself (same object, original order) when the requested count is\n",
    "        at least len(arr); otherwise a new list of the sampled items in\n",
    "        random order.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ratio_or_count is neither a float nor an int.\n",
    "    \"\"\"\n",
    "    # bool is a subclass of int but is never a meaningful count here\n",
    "    if isinstance(ratio_or_count, bool):\n",
    "        raise TypeError(\"ratio_or_count must be a float ratio or an int count\")\n",
    "    if isinstance(ratio_or_count, (float, np.floating)):\n",
    "        num = int(round(float(ratio_or_count)*len(arr)))\n",
    "    elif isinstance(ratio_or_count, (int, np.integer)):\n",
    "        num = int(ratio_or_count)\n",
    "    else:\n",
    "        # explicit raise: a bare `assert False` is stripped under `python -O`\n",
    "        raise TypeError(\"ratio_or_count must be a float ratio or an int count\")\n",
    "    if num >= len(arr):\n",
    "        return arr\n",
    "    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()\n",
    "    return [arr[i] for i in rand_inds]\n",
    "    \n",
    "def split(arr, ratio):\n",
    "    \"\"\"Randomly partition arr into two lists.\n",
    "\n",
    "    round(ratio*len(arr)) randomly chosen items go to the first list and the\n",
    "    remaining items to the second; both keep the relative order of arr.\n",
    "\n",
    "    Returns:\n",
    "        [train, test]\n",
    "    \"\"\"\n",
    "    # a set gives O(1) membership; testing against the list cost O(len(arr)) per item\n",
    "    rand_inds = set(np.random.choice(len(arr), round(ratio*len(arr)), replace=False).tolist())\n",
    "    train, test = [], []\n",
    "    for i, item in enumerate(arr):\n",
    "        if i in rand_inds:\n",
    "            train.append(item)\n",
    "        else:\n",
    "            test.append(item)\n",
    "    return [train, test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# special tokens\n",
    "vocab = vocab + [\"<mask>\", \"<sep>\", \"<a>\", \"</a>\", \"<q>\", \"</q>\"]\n",
    "\n",
    "# every attribute token doubles as a comparison query token and owns three\n",
    "# answer labels <attr_k_0>, <attr_k_1>, <attr_k_2>, which are added to the vocab\n",
    "comp_q_tokens = attributes\n",
    "comp2labels = dict()\n",
    "for comp_q_token in comp_q_tokens:\n",
    "    label_prefix = \"<\"+comp_q_token.strip(\"<>\")\n",
    "    labels = [label_prefix+\"_{}>\".format(i) for i in range(3)]\n",
    "    comp2labels[comp_q_token] = labels\n",
    "    vocab = vocab + labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: every token must appear in vocab exactly once\n",
    "assert len(set(vocab)) == len(vocab)\n",
    "print(\"vocab size:\", len(vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_atomic(entity, attr, val, t):\n",
    "    \"\"\"Build the example dict for one atomic fact (entity, attr) -> val.\n",
    "\n",
    "    The input text is the entity token followed by the attribute token; the\n",
    "    target text repeats the input and appends the value token and '</a>'.\n",
    "    t is stored unchanged under the 'type' key.\n",
    "    \"\"\"\n",
    "    prompt = \"\".join((entity, attr))\n",
    "    answer = \"\".join((\"<{}>\".format(val), \"</a>\"))\n",
    "    example = dict()\n",
    "    example[\"input_text\"] = prompt\n",
    "    example[\"target_text\"] = prompt + answer\n",
    "    example[\"type\"] = t\n",
    "    return example\n",
    "\n",
    "def format_comp(comp_q_token, ent_1, ent_2, label, t):\n",
    "    \"\"\"Build the example dict for one comparison query between two entities.\n",
    "\n",
    "    The input text is: query token, '<q>', first entity, '<mask>', second\n",
    "    entity. The target text repeats the input and appends the label token and\n",
    "    '</a>'. t is stored unchanged under the 'type' key.\n",
    "    \"\"\"\n",
    "    prompt = \"\".join((comp_q_token, \"<q>\", ent_1, \"<mask>\", ent_2))\n",
    "    answer = \"\".join((label, \"</a>\"))\n",
    "    example = dict()\n",
    "    example[\"input_text\"] = prompt\n",
    "    example[\"target_text\"] = prompt + answer\n",
    "    example[\"type\"] = t\n",
    "    return example\n",
    "\n",
    "# fraction of the entities that split() places in the ID group for each attribute\n",
    "num_id_entities_ratio = 0.9\n",
    "\n",
    "# atomic facts, bucketed by whether the entity is ID or OOD for the attribute\n",
    "id_atomic_facts = []\n",
    "ood_atomic_facts = []\n",
    "# comparison examples, bucketed by the split they end up in\n",
    "train_inferred = []\n",
    "test_inferred_iid = []\n",
    "test_inferred_ood = []\n",
    "\n",
    "def compare_ent(ent_1, ent_2, attr):\n",
    "    \"\"\"Compare two entities on one attribute using the values in atomic_KB.\n",
    "\n",
    "    Returns the result of compare() on the two values: 0 if ent_1's value is\n",
    "    smaller, 1 if the values are equal, 2 if ent_1's value is larger.\n",
    "    \"\"\"\n",
    "    attr_ind = attribute2ind[attr]\n",
    "    val_1 = atomic_KB[entity2ind[ent_1], attr_ind]\n",
    "    val_2 = atomic_KB[entity2ind[ent_2], attr_ind]\n",
    "    # compare() expects bracketed value tokens, so wrap the raw values first\n",
    "    return compare(\"<{}>\".format(val_1), \"<{}>\".format(val_2))\n",
    "\n",
    "def _append_both_orders(bucket, comp_q_token, ent_1, ent_2, ty):\n",
    "    \"\"\"Append the comparison example for (ent_1, ent_2), then its flipped twin, to bucket.\"\"\"\n",
    "    for first, second in ((ent_1, ent_2), (ent_2, ent_1)):\n",
    "        label = comp2labels[comp_q_token][compare_ent(first, second, comp_q_token)]\n",
    "        bucket.append(format_comp(comp_q_token, first, second, label, t=ty))\n",
    "\n",
    "for comp_q_token in tqdm(comp_q_tokens):\n",
    "    # a fresh ID/OOD entity split is drawn for every attribute\n",
    "    id_entity_list, ood_entity_list = split(entities, num_id_entities_ratio)\n",
    "    # sets are used for O(1) membership tests only; facts are emitted in list\n",
    "    # order because iterating a set of strings depends on hash randomisation\n",
    "    id_entities, ood_entities = set(id_entity_list), set(ood_entity_list)\n",
    "\n",
    "    for entity in id_entity_list:\n",
    "        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]\n",
    "        id_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='id_atomic'))\n",
    "\n",
    "    for entity in ood_entity_list:\n",
    "        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]\n",
    "        ood_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='ood_atomic'))\n",
    "\n",
    "    # combinations() is consumed lazily, so the full pair list is never materialised\n",
    "    for (ent_1, ent_2) in itertools.combinations(entities, 2):\n",
    "        if ent_1 in ood_entities and ent_2 in ood_entities:\n",
    "            _append_both_orders(test_inferred_ood, comp_q_token, ent_1, ent_2, 'test_inferred_ood')\n",
    "        elif ent_1 in id_entities and ent_2 in id_entities:\n",
    "            # each ID-ID pair goes to the IID test set with probability 0.1, else to training\n",
    "            if np.random.uniform() < 0.1:\n",
    "                _append_both_orders(test_inferred_iid, comp_q_token, ent_1, ent_2, 'test_inferred_iid')\n",
    "            else:\n",
    "                _append_both_orders(train_inferred, comp_q_token, ent_1, ent_2, 'train_inferred')\n",
    "        # pairs mixing one ID and one OOD entity are not used\n",
    "\n",
    "# all five counts go inside the print call; previously the last one sat outside\n",
    "# the parentheses and was shown as a stray (None, N) cell output instead\n",
    "print(len(id_atomic_facts), len(ood_atomic_facts), \"|\", len(train_inferred), len(test_inferred_iid), len(test_inferred_ood))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# upper bound on the number of examples sampled from each probe group\n",
    "test_size = 3000\n",
    "comp_facts_test_ds = choose(test_inferred_ood, test_size)\n",
    "\n",
    "# probe set: OOD comparisons first, then samples of ID atomic facts,\n",
    "# OOD atomic facts and held-out IID comparisons (sampled in that order)\n",
    "probes = (\n",
    "    comp_facts_test_ds\n",
    "    + choose(id_atomic_facts, test_size)\n",
    "    + choose(ood_atomic_facts, test_size)\n",
    "    + choose(test_inferred_iid, test_size)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# downsampling inferred facts included in training\n",
    "for inf_atom_ratio in [12.6,9.0,7.2,3.6]:\n",
    "    dataset_name = \"comparison.{}.{}\".format(num_entities, inf_atom_ratio)\n",
    "    os.makedirs(\"data/{}\".format(dataset_name), exist_ok=True)\n",
    "\n",
    "    # keep inf_atom_ratio inferred facts per ID atomic fact\n",
    "    num_train_inferred = round(inf_atom_ratio*len(id_atomic_facts))\n",
    "    train_inferred_ds = choose(train_inferred, num_train_inferred)\n",
    "\n",
    "    # the test probes additionally include a sample of the inferred facts seen in training\n",
    "    probes_ = probes + choose(train_inferred_ds, test_size)\n",
    "\n",
    "    print(\"train/test atomic, # train inferred:\", len(id_atomic_facts), len(ood_atomic_facts), len(train_inferred_ds))\n",
    "    # path template -> payload, written in this order (vocab last)\n",
    "    outputs = {\n",
    "        \"data/{}/train.json\": id_atomic_facts + ood_atomic_facts + train_inferred_ds,\n",
    "        \"data/{}/valid.json\": comp_facts_test_ds,\n",
    "        \"data/{}/test.json\": probes_,\n",
    "        \"data/{}/vocab.json\": vocab,\n",
    "    }\n",
    "    for path_template, payload in outputs.items():\n",
    "        with open(path_template.format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "            json.dump(payload, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CLM",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
